import torch
import torch.nn as nn
import torch.nn.functional as F

GRAD_CLIP = 0.01


class GradClip(torch.autograd.Function):

  @staticmethod
  def forward(ctx, x):
    return x

  @staticmethod
  def backward(ctx, grad_x):
    o = torch.zeros_like(grad_x)
    grad_x = torch.where(grad_x.abs() > GRAD_CLIP, o, grad_x)
    grad_x = torch.where(torch.isnan(grad_x), o, grad_x)
    return grad_x


class GradientClip(nn.Module):

  def __init__(self):
    super(GradientClip, self).__init__()

  def forward(self, x):
    return GradClip.apply(x)
